"""
一个用来可视化OpenSeesPy模型的类
作者: 闫业祥
微信: yx592715024

"""
import os
import sys
import pyvista as pv

sys.path.append(os.path.abspath(os.path.dirname(__file__)))
from _ModelBase import _check_file
from _ModelBase import *


def _generate_mesh(points, cells, kind='line'):
    """
    generate the mesh from the points and cells
    """
    if kind == 'line':
        pltr = pv.PolyData()
        pltr.points = points
        pltr.lines = np.array(cells)
    elif kind == 'face':
        pltr = pv.PolyData()
        pltr.points = points
        pltr.faces = np.hstack(cells)
    else:
        raise
    return pltr


class OpenSeesVis(model_base):
    """
    PyVista-based viewer for OpenSeesPy models.

    Data is taken either from the current OpenSees domain through the collection
    methods inherited from ``model_base`` (``get_model_data``, ``get_eigen_data``
    and the step-result storage) or from a previously saved ``.hdf5`` file.
    """

    # Subplot grid (rows, cols) used by ``eigen_vis`` when ``gridspec`` is on,
    # keyed by the number of modes drawn.
    _GRID_SHAPES = {1: (1, 1), 2: (1, 2), 3: (1, 3), 4: (2, 2), 5: (2, 3), 6: (2, 3),
                    7: (3, 3), 8: (3, 3), 9: (3, 3), 10: (3, 4), 11: (3, 4), 12: (3, 4),
                    13: (4, 4), 14: (4, 4), 15: (4, 4), 16: (4, 4), 17: (4, 5), 18: (4, 5),
                    19: (4, 5), 20: (4, 5), 21: (5, 5), 22: (5, 5), 23: (5, 5), 24: (5, 5),
                    25: (5, 5), 26: (5, 6), 27: (5, 6), 28: (5, 6), 29: (5, 6), 30: (5, 6),
                    31: (6, 6), 32: (6, 6), 33: (6, 6), 34: (6, 6), 35: (6, 6), 36: (6, 6)}
    # Response names accepted by ``show_variable`` (compared case-insensitively).
    _SHOW_VARIABLES = ('disp', 'vel', 'accel')

    def __init__(self, point_size: float = 6, line_width: float = 6, colors_dict: dict = None,
                 theme: str = 'paraview', notebook: bool = False):
        """
        :param point_size: rendered node size.
        :param line_width: line element width.
        :param colors_dict: overrides for the default colors; keys used are
            'point', 'line', 'face', 'solid', 'truss' and 'link'.
        :param theme: PyVista plot theme, e.g. 'default', 'paraview', 'document', 'dark'.
        :param notebook: whether to render inline in a Jupyter notebook.
        :return: None
        """
        super().__init__()
        # ------------------------------
        self.point_size = point_size
        self.line_width = line_width
        self.notebook = notebook
        # Default colors, optionally overridden by the caller.
        colors = dict(point='#840000', line='#0165fc', face='#06c2ac', solid='#f48924',
                      truss="#7552cc", link="#00c16e")
        if colors_dict is not None:
            colors.update(colors_dict)
        self.color_point = colors['point']
        self.color_line = colors['line']
        self.color_face = colors['face']
        self.color_solid = colors['solid']
        self.color_truss = colors['truss']
        self.color_link = colors['link']
        # -------------------------------------------------
        # NOTE: sets PyVista's global theme, not just for this instance.
        pv.set_plot_theme(theme)
        # -------------------------------------------------
        self.model_data_finished = False  # whether model data has already been extracted
        self.mode_tag = None

    # ------------------------------------------------------------------
    # private helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _norms(vectors):
        """Euclidean norm of each row of ``vectors``."""
        return np.sqrt(np.sum(vectors ** 2, axis=1))

    @staticmethod
    def _scale_factor(max_bound, peak):
        """
        Deformation scale mapping a nodal magnitude of ``peak`` to one tenth of ``max_bound``.

        Returns 1.0 when ``peak`` is zero (nothing to scale) instead of dividing by zero.
        """
        if peak == 0:
            return 1.0
        return max_bound / 10 / peak

    @staticmethod
    def _count_modes(f):
        """Number of modes described by the frequency data ``f``: its length if it is a
        non-scalar list/tuple/ndarray, otherwise 1."""
        if isinstance(f, (list, tuple, np.ndarray)) and np.ndim(f) > 0:
            return len(f)
        return 1

    def _select_mode(self, f, eigenvector, mode_tag):
        """
        Return ``(eigen_vec, frequency)`` for the 1-based ``mode_tag``.

        :raises ValueError: if more than one mode is stored but fewer than ``mode_tag``.
        """
        num_modes = self._count_modes(f)
        if num_modes > 1:
            if mode_tag > num_modes:
                raise ValueError("Insufficient number of modes in open file")
            return eigenvector[mode_tag - 1], f[mode_tag - 1]
        # A single mode may be stored bare (scalar f, 2-D eigenvector) or wrapped
        # in a length-1 container (e.g. read back from hdf5); unwrap the latter.
        eigen_vec = np.asarray(eigenvector)
        if eigen_vec.ndim == 3:
            eigen_vec = eigen_vec[0]
        return eigen_vec, float(np.ravel(f)[0])

    def _load_eigen_data(self, open_file):
        """Read every ``self.eigen_names`` dataset from the 'EigenInfo' group of ``open_file``."""
        _check_file(open_file, '.hdf5')
        with h5py.File(open_file, 'r') as f:
            return {name: f["EigenInfo"][name][...] for name in self.eigen_names}

    def _load_step_data(self, open_file):
        """
        Return ``(node_resp_steps, cells_steps, model_info_steps)``.

        Read from ``open_file`` when it is given, otherwise from the step results
        stored on this instance.

        :raises ValueError: if ``open_file`` does not end with ``.hdf5``, or if no file
            is given and no step results have been collected.
        """
        if open_file is not None:
            if not open_file.endswith('.hdf5'):
                raise ValueError('open_file must be endswith .hdf5')
        if open_file:
            with h5py.File(open_file, 'r') as f:
                node_resp = {name: f["NodeRespSteps"][name][...] for name in self.node_resp_names}
                cells = {name: f["CellSteps"][name][...] for name in self.cells_names}
                model_info = {name: f["ModelInfoSteps"][name][...] for name in self.model_info_names}
            return node_resp, cells, model_info
        if self.step_track == 0:
            raise ValueError('method get_step_results() was not invoked!')
        return self.node_resp_steps, self.cells_steps, self.model_info_steps

    def _check_show_variable(self, show_variable):
        """Return the lower-cased response name, or raise ValueError if it is unsupported."""
        key = show_variable.lower()
        if key not in self._SHOW_VARIABLES:
            raise ValueError('show_variable must be disp, vel, or accel!')
        return key

    @staticmethod
    def _peak_text(show_variable, step_label, data):
        """Text block with the per-axis min/max of ``data`` (columns 0, 1, 2 = x, y, z)."""
        return ('peak of {}, step: {}\n'
                'min.x = {:.3f}  max.x = {:.3f}\n'
                'min.y = {:.3f}  max.y = {:.3f}\n'
                'min.z = {:.3f}  max.z = {:.3f}\n').format(
            show_variable, step_label,
            np.min(data[:, 0]), np.max(data[:, 0]),
            np.min(data[:, 1]), np.max(data[:, 1]),
            np.min(data[:, 2]), np.max(data[:, 2]))

    @staticmethod
    def _set_camera(plotter, model_dims):
        """Isometric view, switched to the XY plane for models of dimension <= 2."""
        plotter.view_isometric()
        if np.max(model_dims) <= 2:
            plotter.view_xy(negative=False)

    @staticmethod
    def _update_meshes(plotter, points, scalars, point_plot, line_plot, face_plot):
        """Move all meshes to ``points`` and recolor them with ``scalars`` (no render)."""
        plotter.update_coordinates(points, mesh=point_plot, render=False)
        plotter.update_scalars(scalars, mesh=point_plot, render=False)
        for mesh in (line_plot, face_plot):
            if mesh is not None:
                plotter.update_coordinates(points, mesh=mesh, render=False)
                plotter.update_scalars(scalars, mesh=mesh, render=False)

    # ------------------------------------------------------------------
    # public API
    # ------------------------------------------------------------------
    def model_vis(self, open_file: str = None, node_label: bool = True,
                  ele_label: bool = True, outline: bool = True, opacity: float = 0.75):
        """
        Plot the undeformed model.

        :param open_file: saved hdf5 file to read, must end with ``.hdf5``; if None,
            data is taken from the current domain.
        :param node_label: bool, whether to show node tags.
        :param ele_label: bool, whether to show element tags.
        :param outline: bool, whether to show the axis bounding box.
        :param opacity: float, opacity of plane and solid elements.
        :return: None
        """
        if open_file:
            _check_file(open_file, '.hdf5')
            with h5py.File(open_file, 'r') as f:
                model_info = {name: f["ModelInfo"][name][...] for name in self.model_info_names}
                cells = {name: f["Cell"][name][...] for name in self.cells_names}
        else:
            if not self.model_data_finished:
                self.get_model_data()
            model_info = self.model_info
            cells = self.cells

        coords = model_info['coord_no_deform']
        plotter = pv.Plotter(notebook=self.notebook)
        plotter.add_mesh(pv.PolyData(coords), color=self.color_point,
                         point_size=self.point_size, render_points_as_spheres=True)

        # (cells key, color, render as tubes, line width) for each line-type group, in draw order.
        line_groups = (
            ('truss', self.color_truss, True, self.line_width),
            ('link', self.color_link, False, self.line_width / 5),
            ('beam', self.color_line, False, self.line_width / 5),
            ('other_line', self.color_line, True, self.line_width),
        )
        for key, color, as_tubes, width in line_groups:
            if len(cells[key]) > 0:
                plotter.add_mesh(_generate_mesh(coords, cells[key], kind='line'), color=color,
                                 render_lines_as_tubes=as_tubes, line_width=width)

        # (cells key, color) for each face/solid group, in draw order.
        face_groups = (
            ('plane', self.color_face),
            ('tetrahedron', self.color_solid),
            ('brick', self.color_solid),
        )
        for key, color in face_groups:
            if len(cells[key]) > 0:
                plotter.add_mesh(_generate_mesh(coords, cells[key], kind='face'), color=color,
                                 show_edges=True, opacity=opacity)

        plotter.add_text('OpenSees 3D View', position='upper_left',
                         font_size=15, color='black', font='courier', viewport=True)
        plotter.add_text('Num. of Node: {0} \n Num. of Ele:{1}'
                         .format(model_info['num_node'], model_info['num_ele']),
                         position='upper_right', font_size=10, color='black', font='courier')
        if outline:
            plotter.show_bounds(grid=False, location='outer', bounds=model_info['bound'],
                                show_zaxis=True, color="black")
        if node_label:
            node_labels = [str(i) for i in model_info['NodeTags']]
            plotter.add_point_labels(coords, node_labels, text_color='white',
                                     font_size=12, bold=False, always_visible=True)
        if ele_label:
            ele_labels = [str(i) for i in model_info['EleTags']]
            plotter.add_point_labels(model_info['coord_ele_midpoints'], ele_labels, text_color='#ff796c',
                                     font_size=12, bold=False, always_visible=True)
        plotter.add_axes(color='black')
        self._set_camera(plotter, model_info['model_dims'])
        plotter.show()
        plotter.close()

    def _generate_all_mesh(self, plotter, points, scalars, opacity, colormap,
                           lines_cells, face_cells, show_origin=False, points_origin=None):
        """
        Add points, lines and faces colored by ``scalars`` to ``plotter``.

        :param plotter: target ``pyvista.Plotter``.
        :param points: node coordinates to draw (typically deformed).
        :param scalars: one value per node; the color range is fixed to
            ``[min(scalars), max(scalars)]``.
        :param opacity: opacity of the face mesh.
        :param colormap: colormap name.
        :param lines_cells: line connectivity (may be empty).
        :param face_cells: face connectivity (may be empty).
        :param show_origin: also draw the ``points_origin`` geometry in gray.
        :param points_origin: reference coordinates; required when ``show_origin`` is True.
        :return: ``(point_plot, line_plot, face_plot)``; the line/face entries are None
            when the corresponding connectivity is empty.
        :raises ValueError: if ``show_origin`` is True but ``points_origin`` is None.
        """
        if show_origin and points_origin is None:
            raise ValueError('points_origin is required when show_origin is True')
        clim = [np.min(scalars), np.max(scalars)]

        point_plot = pv.PolyData(points)
        plotter.add_mesh(point_plot, colormap=colormap, scalars=scalars, clim=clim,
                         interpolate_before_map=True, point_size=self.point_size,
                         render_points_as_spheres=True, show_scalar_bar=False)

        line_plot = None
        if len(lines_cells) > 0:
            if show_origin:
                line_plot_origin = _generate_mesh(points_origin, lines_cells, kind='line')
                plotter.add_mesh(line_plot_origin, color='gray', line_width=self.line_width / 3,
                                 show_scalar_bar=False)
            line_plot = _generate_mesh(points, lines_cells, kind='line')
            plotter.add_mesh(line_plot, colormap=colormap, scalars=scalars,
                             interpolate_before_map=True, clim=clim, show_scalar_bar=False,
                             render_lines_as_tubes=True, line_width=self.line_width)

        face_plot = None
        if len(face_cells) > 0:
            if show_origin:
                face_plot_origin = _generate_mesh(points_origin, face_cells, kind='face')
                plotter.add_mesh(face_plot_origin, color='gray', style='wireframe', show_scalar_bar=False,
                                 show_edges=True, line_width=self.line_width / 3)
            face_plot = _generate_mesh(points, face_cells, kind='face')
            plotter.add_mesh(face_plot, colormap=colormap, scalars=scalars, clim=clim,
                             show_edges=True, opacity=opacity,
                             interpolate_before_map=True, show_scalar_bar=False)

        return point_plot, line_plot, face_plot

    def eigen_vis(self, open_file: str = None, mode_tag: int = 1, alpha: float = None,
                  gridspec: bool = True, outline: bool = False, show_origin: bool = False,
                  colormap: str = 'coolwarm', opacity: float = 1):
        """
        Plot mode shapes.

        :param open_file: str, saved hdf5 file to read, must end with ``.hdf5``; if None,
            eigen data is obtained from the current domain via ``get_eigen_data``.
        :param mode_tag: int, 1-based mode number. With ``gridspec`` on, modes
            1..mode_tag are all drawn.
        :param alpha: float, deformation scale factor. Default: the largest nodal
            amplitude is scaled to one tenth of the model's ``max_bound``.
        :param gridspec: bool, draw modes 1..mode_tag in a subplot grid (at most 36);
            ignored when ``mode_tag`` is 1.
        :param outline: bool, whether to show the axis bounding box.
        :param show_origin: bool, whether to overlay the undeformed shape in gray.
        :param colormap: colormap name, e.g. jet, rainbow, hot, afmhot, copper,
                            winter, cool, coolwarm, gist_earth, bone, binary, gray
        :param opacity: float, opacity of plane and solid elements.
        :raises ValueError: if ``mode_tag`` < 1, exceeds 36 in grid mode, or exceeds
            the number of available modes.
        """
        if mode_tag < 1:
            raise ValueError('mode_tag must be >= 1')
        if mode_tag == 1:
            gridspec = False
        if gridspec and mode_tag > 36:
            # Checked before computing/reading eigen data so the error is immediate.
            raise ValueError("When gridspec is True, mode_tag must be <= 36 for clarity")

        if open_file:
            eigen_data = self._load_eigen_data(open_file)
        else:
            self.get_eigen_data(mode_tag, serialize=bool(gridspec))
            eigen_data = self.eigen

        f = eigen_data['f']
        eigenvector = eigen_data['eigenvector']
        coords = eigen_data['coord_no_deform']

        if gridspec:
            if mode_tag > self._count_modes(f):
                raise ValueError("Insufficient number of modes in open file")
            shape = self._GRID_SHAPES[mode_tag]
            plotter = pv.Plotter(notebook=self.notebook, shape=shape)
            for i in range(mode_tag):
                eigen_vec = eigenvector[i]
                scalars = self._norms(eigen_vec)
                if alpha is None:
                    alpha_ = self._scale_factor(eigen_data['max_bound'], np.max(scalars))
                else:
                    alpha_ = alpha
                row, col = divmod(i, shape[1])
                plotter.subplot(row, col)
                self._generate_all_mesh(plotter, coords + eigen_vec * alpha_, scalars, opacity, colormap,
                                        eigen_data['all_lines'], eigen_data['all_faces'],
                                        show_origin=show_origin, points_origin=coords)
                plotter.add_text('Mode {}\nf = {:.3f} Hz\nT = {:.3f} s'
                                 .format(i + 1, f[i], 1 / f[i]),
                                 position='upper_right', font_size=10, color='black', font='courier')
                if outline:
                    plotter.show_bounds(grid=False, location='outer', bounds=eigen_data['bound'],
                                        show_zaxis=True, color="black")
                plotter.add_axes(color='black')
            # Link the cameras once, after every subplot exists.
            plotter.link_views()
        else:
            eigen_vec, f_ = self._select_mode(f, eigenvector, mode_tag)
            scalars = self._norms(eigen_vec)
            if alpha is None:
                alpha_ = self._scale_factor(eigen_data['max_bound'], np.max(scalars))
            else:
                alpha_ = alpha
            plotter = pv.Plotter(notebook=self.notebook)
            self._generate_all_mesh(plotter, coords + eigen_vec * alpha_, scalars, opacity, colormap,
                                    eigen_data['all_lines'], eigen_data['all_faces'],
                                    show_origin=show_origin, points_origin=coords)
            plotter.add_scalar_bar(color='#000000', n_labels=10, label_font_size=12)
            plotter.add_text('Mode {}\nf = {:.3f} Hz\nT = {:.3f} s'
                             .format(mode_tag, f_, 1 / f_),
                             position='upper_right', font_size=12, color='black', font='courier')
            if outline:
                plotter.show_bounds(grid=False, location='outer', bounds=eigen_data['bound'],
                                    show_zaxis=True, color="black")
            plotter.add_axes(color='black')
        self._set_camera(plotter, eigen_data['model_dims'])
        plotter.show()
        plotter.close()

    def eigen_animation(self, output_file: str = None, open_file: str = None,
                        mode_tag: int = 1, alpha: float = None, outline: bool = False,
                        colormap: str = 'coolwarm', opacity: float = 1, framerate: int = 4):
        """
        Write an animation of one mode shape oscillating about the undeformed shape.

        :param output_file: str, output file; required, must end with .gif or .mp4
            (``framerate`` only applies to .mp4).
        :param open_file: saved hdf5 file to read, must end with ``.hdf5``; if None,
            eigen data is obtained from the current domain.
        :param mode_tag: int, 1-based mode number.
        :param alpha: float, deformation scale factor. Default: the largest nodal
            amplitude is scaled to one tenth of the model's ``max_bound``.
        :param outline: bool, whether to show the axis bounding box.
        :param colormap: colormap name, e.g. jet, rainbow, hot, afmhot, copper,
                            winter, cool, coolwarm, gist_earth, bone, binary, gray
        :param opacity: float, opacity of plane and solid elements.
        :param framerate: int, frames per second (smaller is slower); .mp4 only.
        :raises ValueError: if ``output_file`` is missing or has the wrong suffix,
            ``mode_tag`` < 1, or the requested mode is not available.
        """
        # Validate before any data collection so a bad call fails fast
        # (previously a missing output_file crashed with AttributeError mid-way).
        if output_file is None or not output_file.endswith(('.gif', '.mp4')):
            raise ValueError('output_file must be endwith .gif or .mp4')
        if mode_tag < 1:
            raise ValueError('mode_tag must be >= 1')

        if open_file:
            eigen_data = self._load_eigen_data(open_file)
        else:
            self.get_eigen_data(mode_tag, serialize=False)
            eigen_data = self.eigen

        eigen_vec, f_ = self._select_mode(eigen_data['f'], eigen_data['eigenvector'], mode_tag)
        scalars = self._norms(eigen_vec)
        if alpha is None:
            alpha_ = self._scale_factor(eigen_data['max_bound'], np.max(scalars))
        else:
            alpha_ = alpha
        coords = eigen_data['coord_no_deform']
        eigen_points = coords + eigen_vec * alpha_
        anti_eigen_points = coords - eigen_vec * alpha_

        plotter = pv.Plotter(notebook=self.notebook)
        point_plot, line_plot, face_plot = \
            self._generate_all_mesh(plotter, eigen_points, scalars, opacity, colormap,
                                    eigen_data['all_lines'], eigen_data['all_faces'],
                                    show_origin=False, points_origin=coords)
        plotter.add_scalar_bar(color='#000000', n_labels=10, label_font_size=12)
        plotter.add_text('Mode {}\nf = {:.3f} Hz\nT = {:.3f} s'
                         .format(mode_tag, f_, 1 / f_),
                         position='upper_right', font_size=12, color='black', font='courier')
        if outline:
            plotter.show_bounds(grid=False, location='outer', bounds=eigen_data['bound'],
                                show_zaxis=True, color="black")
        plotter.add_axes(color='black')
        self._set_camera(plotter, eigen_data['model_dims'])

        # Animation
        # ----------------------------------------------------------------------------------
        if output_file.endswith('.gif'):
            plotter.open_gif(output_file)
        else:
            plotter.open_movie(output_file, framerate=framerate)
        # Each frame's color is the unscaled mode amplitude at that position:
        # zero at the undeformed shape, the eigenvector norm at either extreme.
        frames = [(anti_eigen_points, scalars), (coords, np.zeros_like(scalars)), (eigen_points, scalars)]
        for idx in [1, 2, 1, 0] * 10:
            points, magnitude = frames[idx]
            self._update_meshes(plotter, points, magnitude, point_plot, line_plot, face_plot)
            plotter.update_scalar_bar_range(clim=[np.min(magnitude), np.max(magnitude)], name=None)
            plotter.write_frame()
        # ----------------------------------------------------------------------------------
        plotter.show()
        plotter.close()

    def deform_vis(self, step_tag: int = -1, open_file: str = None, show_variable: str = 'disp',
                   alpha: float = None, outline: bool = False, show_origin: bool = False,
                   colormap: str = 'coolwarm', opacity: float = 1):
        """
        Plot the deformed shape at one analysis step, colored by a nodal response.

        :param step_tag: int, 1-based analysis step to show; -1 means the last step.
        :param open_file: str, saved hdf5 data file; if None, results stored on this
            instance are used.
        :param show_variable: str, response to display: 'disp', 'vel' or 'accel'.
        :param alpha: float, deformation scale factor. Default: the largest nodal
            displacement is scaled to one tenth of the largest bounding-box extent.
        :param outline: bool, whether to show the axis bounding box.
        :param show_origin: bool, whether to overlay the undeformed shape in gray.
        :param colormap: colormap name, e.g. jet, rainbow, hot, afmhot, copper,
                            winter, cool, coolwarm, gist_earth, bone, binary, gray
        :param opacity: float, opacity of plane and solid elements.
        :raises ValueError: on a bad ``open_file`` suffix, unsupported ``show_variable``,
            missing step results, or ``step_tag`` out of range.
        """
        variable = self._check_show_variable(show_variable)
        node_resp, cells_steps, info_steps = self._load_step_data(open_file)

        # Count steps from the data actually loaded (also correct for hdf5 input,
        # where self.step_track does not describe the file).
        step_num = len(node_resp['disp'])
        if step_tag == -1:
            step_tag = step_num
        if not 1 <= step_tag <= step_num:
            raise ValueError('step_tag must be -1 or in [1, {}]'.format(step_num))
        step = step_tag - 1

        node_disp = node_resp['disp'][step]
        data = node_resp[variable][step]
        if self.model_update:
            # Use geometry from the same step as the cells/bounds so indices agree.
            coords = info_steps['coord_no_deform'][step]
            bounds = info_steps['bound'][step]
            model_dims = info_steps['model_dims'][step]
            lines_cells = cells_steps['all_lines'][step]
            faces_cells = cells_steps['all_faces'][step]
        else:
            coords = info_steps['coord_no_deform']
            bounds = info_steps['bound']
            model_dims = info_steps['model_dims']
            lines_cells = cells_steps['all_lines']
            faces_cells = cells_steps['all_faces']

        if alpha is None:
            max_bound = np.max([bounds[1] - bounds[0], bounds[3] - bounds[2], bounds[5] - bounds[4]])
            alpha_ = self._scale_factor(max_bound, np.max(self._norms(node_disp)))
        else:
            alpha_ = alpha
        deformed = alpha_ * node_disp + coords
        scalars = self._norms(data)

        plotter = pv.Plotter(notebook=self.notebook)
        self._generate_all_mesh(plotter, deformed, scalars, opacity, colormap,
                                lines_cells, faces_cells,
                                show_origin=show_origin, points_origin=coords)
        plotter.add_scalar_bar(color='#000000', n_labels=10, label_font_size=16, title=show_variable)
        plotter.add_text('OpenSees 3D View', position='upper_left', shadow=True,
                         font_size=16, color='black', font='courier')
        plotter.add_text(self._peak_text(show_variable, step + 1, data),
                         position='upper_right', shadow=True,
                         font_size=12, color='black', font='courier')
        if outline:
            plotter.show_bounds(grid=False, location='outer', bounds=bounds,
                                show_zaxis=True, color="black")
        plotter.add_axes(color='black')
        self._set_camera(plotter, model_dims)
        plotter.show()
        plotter.close()

    def deform_animation(self, output_file: str, open_file: str = None, show_variable: str = 'disp',
                         alpha: float = None, outline: bool = False, colormap: str = 'coolwarm',
                         opacity: float = 1, framerate: int = 24):
        """
        Write an animation of the deformed shape over all stored analysis steps.

        :param output_file: str, output file, must end with .gif or .mp4.
        :param open_file: str, saved hdf5 data file; if None, results stored on this
            instance are used.
        :param show_variable: str, response to display: 'disp', 'vel' or 'accel'.
        :param alpha: float, deformation scale factor. Default: the largest nodal
            displacement over all steps is scaled to one tenth of ``max_bound``.
        :param outline: bool, whether to show the axis bounding box.
        :param colormap: colormap name, e.g. jet, rainbow, hot, afmhot, copper,
                            winter, cool, coolwarm, gist_earth, bone, binary, gray
        :param opacity: float, opacity of plane and solid elements.
        :param framerate: int, frames per second; .mp4 only.
        :raises ValueError: on a bad ``output_file``/``open_file`` suffix, unsupported
            ``show_variable``, or missing step results.
        """
        if not output_file.endswith(('.gif', '.mp4')):
            raise ValueError('output_file must be endwith .gif or .mp4')
        # Validate before the gif/movie is opened so no partial file is left behind.
        variable = self._check_show_variable(show_variable)
        node_resp, cells_steps, info_steps = self._load_step_data(open_file)

        # NOTE(review): the mesh topology is built once from step 0; if the model
        # changes during the analysis (model_update), later topology is not shown.
        if self.model_update:
            max_bound = info_steps['max_bound'][0]
            coords = info_steps['coord_no_deform'][0]
            bounds = info_steps['bound'][0]
            model_dims = info_steps['model_dims'][0]
            lines_cells = cells_steps['all_lines'][0]
            faces_cells = cells_steps['all_faces'][0]
        else:
            max_bound = info_steps['max_bound']
            coords = info_steps['coord_no_deform']
            bounds = info_steps['bound']
            model_dims = info_steps['model_dims']
            lines_cells = cells_steps['all_lines']
            faces_cells = cells_steps['all_faces']

        disp_steps = node_resp['disp']
        if alpha is None:
            # One scale for the whole animation, driven by the largest displacement
            # of any step (equivalent to the smallest per-step scale).
            peak = max(np.max(self._norms(node_disp)) for node_disp in disp_steps)
            alpha_ = self._scale_factor(max_bound, peak)
        else:
            alpha_ = alpha

        # Initial scene
        # ----------------------------------------------------------------------------------
        plotter = pv.Plotter(notebook=self.notebook)
        point_plot, line_plot, face_plot = \
            self._generate_all_mesh(plotter, coords, coords[:, 0] * 0, opacity, colormap,
                                    lines_cells, faces_cells,
                                    show_origin=False, points_origin=None)
        plotter.add_scalar_bar(color='#000000', n_labels=10, label_font_size=16, title=show_variable)
        if outline:
            plotter.show_bounds(grid=False, location='outer', bounds=bounds, show_zaxis=True,
                                color="black")
        plotter.add_axes(color='black')
        self._set_camera(plotter, model_dims)

        # Animation frames
        # -----------------------------------------------------------------------------
        if output_file.endswith('.gif'):
            plotter.open_gif(output_file)
        else:
            plotter.open_movie(output_file, framerate=framerate, quality=5)
        step_num = len(disp_steps)
        for step, node_disp in enumerate(disp_steps):
            data = node_resp[variable][step]
            scalars = self._norms(data)
            deformed = alpha_ * node_disp + coords
            self._update_meshes(plotter, deformed, scalars, point_plot, line_plot, face_plot)
            plotter.update_scalar_bar_range(clim=[np.min(scalars), np.max(scalars)])
            txt = plotter.add_text(self._peak_text(show_variable, step + 1, data),
                                   position='upper_right', font_size=12, color='black', font='courier')
            plotter.write_frame()
            # Keep the last step's text on screen for the final interactive view.
            if step < step_num - 1:
                plotter.remove_actor(txt)

        plotter.show()
        plotter.close()


if __name__ == "__main__":
    # Demo script. Importing CableStayedBridge presumably builds the example model
    # in the current OpenSees domain (verify against that module).
    from CableStayedBridge import *

    A = OpenSeesVis(theme='document', line_width=2, point_size=2)
    A.model_vis(node_label=False, ele_label=False)
    A.eigen_vis(mode_tag=6, gridspec=True, colormap='coolwarm', show_origin=False)
    # Write the animation next to this script instead of a hard-coded,
    # user-specific absolute path that only exists on one machine.
    gif_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'eigen_mode1.gif')
    A.eigen_animation(output_file=gif_path, mode_tag=1, alpha=None, outline=False,
                      colormap='coolwarm', opacity=1)
